# Copyright 2023 The StableHLO Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@llvm-project//llvm:lit_test.bzl", "lit_test")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@rules_cc//cc:defs.bzl", "cc_library")

# Package-level defaults applied to every target declared in this file.
package(
    licenses = ["notice"],
    default_visibility = ["//visibility:public"],
)

# C++ library built from CheckOps.cpp / CheckOps.h. The TableGen-generated
# .inc files it needs are provided by :check_ops_inc_gen.
cc_library(
    name = "check_ops",
    srcs = ["CheckOps.cpp"],
    hdrs = ["CheckOps.h"],
    strip_include_prefix = ".",
    deps = [
        # Generated op declarations/definitions.
        ":check_ops_inc_gen",
        # Targets from the workspace root package.
        "//:base",
        "//:reference_element",
        "//:reference_errors",
        "//:reference_numpy",
        "//:reference_tensor",
        "//:reference_token",
        "//:reference_types",
        "//:reference_value",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:BytecodeOpInterface",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
    ],
)

# Runs mlir-tblgen over CheckOps.td to produce the op declaration and
# definition include files.
gentbl_cc_library(
    name = "check_ops_inc_gen",
    strip_include_prefix = ".",
    tbl_outs = [
        # (tblgen flags, output file)
        (["-gen-op-decls"], "CheckOps.h.inc"),
        (["-gen-op-defs"], "CheckOps.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "CheckOps.td",
    deps = [":check_ops_td_files"],
)

# TableGen sources for CheckOps, plus the .td libraries they depend on.
td_library(
    name = "check_ops_td_files",
    srcs = ["CheckOps.td"],
    includes = ["."],
    deps = ["//:base_td_files"],
)

# C++ library built from TestUtils.cpp / TestUtils.h. The TableGen-generated
# include file it needs is provided by :test_utils_inc_gen.
cc_library(
    name = "test_utils",
    srcs = ["TestUtils.cpp"],
    hdrs = ["TestUtils.h"],
    strip_include_prefix = ".",
    deps = [
        # Generated pass declarations.
        ":test_utils_inc_gen",
        # Target from the workspace root package.
        "//:stablehlo_assembly_format",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Rewrite",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:SideEffectInterfaces",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
        "@llvm-project//mlir:Transforms",
    ],
)

# Runs mlir-tblgen over TestUtils.td to produce the pass declarations include
# file, using "HloTest" as the -name argument.
gentbl_cc_library(
    name = "test_utils_inc_gen",
    strip_include_prefix = ".",
    tbl_outs = [
        # (tblgen flags, output file)
        (["-gen-pass-decls", "-name=HloTest"], "TestUtils.h.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "TestUtils.td",
    deps = [":test_utils_td_files"],
)

# TableGen sources for TestUtils, plus the MLIR pass-base .td files they
# depend on.
td_library(
    name = "test_utils_td_files",
    srcs = ["TestUtils.td"],
    includes = ["."],
    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

# Equivalent of configure_lit_site_cfg from CMakeLists.txt: generates
# lit.site.cfg.py from the lit.site.cfg.py.in template by textual substitution.
expand_template(
    name = "lit_site_cfg_py_gen",
    testonly = True,
    out = "lit.site.cfg.py",
    substitutions = {
        # These two keys include the surrounding double quotes, so the whole
        # quoted placeholder is replaced by the bare Python name RUNFILES_DIR,
        # which is defined by the header substitution below.
        "\"@STABLEHLO_SOURCE_DIR@\"": "RUNFILES_DIR",
        "\"@STABLEHLO_TOOLS_DIR@\"": "RUNFILES_DIR",
        # Header replacement: Python code that resolves lit.cfg.py through the
        # runfiles library and sets RUNFILES_DIR to the absolute, POSIX-style
        # path of the directory two levels above the one containing that file.
        "@LIT_SITE_CFG_IN_HEADER@": '''# Autogenerated, do not edit.
from python.runfiles import Runfiles
from pathlib import Path

r = Runfiles.Create()
LITE_CFG_PY = Path(r.Rlocation("stablehlo/stablehlo/tests/lit.cfg.py"))
RUNFILES_DIR = LITE_CFG_PY.parents[2].absolute().as_posix()''',
        # Relative path substituted for the LLVM tools directory placeholder.
        "@LLVM_TOOLS_DIR@": "../llvm-project/llvm",
    },
    template = "lit.site.cfg.py.in",
)

# Equivalent of add_lit_testsuite from CMakeLists.txt: declares one lit_test
# target, named "<path>.test", for every .mlir file under this package.
[
    lit_test(
        name = "%s.test" % src,
        size = "small",
        srcs = [src],
        # Files every test needs at run time: the lit configuration plus the
        # tool binaries listed here.
        data = [
            "lit.cfg.py",
            "lit.site.cfg.py",
            "//:stablehlo-opt",
            "//:stablehlo-translate",
            "@llvm-project//llvm:FileCheck",
            "@llvm-project//llvm:not",
        ] + glob(
            # Optional sibling "<test>.bc" file. A test need not have one, so
            # the glob must be allowed to match nothing; without allow_empty
            # an empty match is an error when
            # --incompatible_disallow_empty_glob is in effect.
            ["%s.bc" % src],
            allow_empty = True,
        ),
        # Tag used by the :stablehlo_tests test_suite to select these targets.
        tags = ["stablehlo_tests"],
        deps = ["@rules_python//python/runfiles"],
    )
    for src in glob(["**/*.mlir"])
]

# Suite selecting the tests that carry the "stablehlo_tests" tag.
test_suite(name = "stablehlo_tests", tags = ["stablehlo_tests"])
